Skip to content

[Pytorch][Common] Hybrid quantization#2817

Open
negvet wants to merge 50 commits into
NVIDIA:mainfrom
negvet:hybrid_quantization
Open

[Pytorch][Common] Hybrid quantization#2817
negvet wants to merge 50 commits into
NVIDIA:mainfrom
negvet:hybrid_quantization

Conversation

@negvet

@negvet negvet commented Mar 31, 2026

Copy link
Copy Markdown
Collaborator

Description

Hybrid (per-direction) quantization. Hybrid means rowwise/colwise can use different formats via CustomRecipe(qfactory).
This is an experimental feature.
The main problem that it tries to solve is that precision requirements are non-uniform.

Current recipes set one format for both rowwise and colwise directions.
Hybrid quantization enables, e.g. MXFP8 fwd and NVFP4 bwd (or vice versa) or any other valid combination. No need for a hardcoded recipe for every combination.

Composer-style (Composer 2 paper) grouped GEMM recipe, e.g. row-scaled NVFP4 fwd + MXFP8 bwd:

# CustomRecipe calls quantization_factory(role) for each quantized tensor
# Factory chooses formats

def hybrid_factory(role):
    is_grouped_linear = role is not None and role.module_type == "grouped_linear"
    is_linear = role is not None and role.module_type == "linear"

    if is_grouped_linear and role.tensor_type == "input":
        return HybridQuantizer(
            rowwise_quantizer=NVFP4Quantizer(row_scaled_nvfp4=True, ...),
            columnwise_quantizer=MXFP8Quantizer(...),
        )

    if is_grouped_linear and role.tensor_type == "weight":
        return HybridQuantizer(
            rowwise_quantizer=NVFP4Quantizer(...),
            columnwise_quantizer=MXFP8Quantizer(...),
        )

    if is_grouped_linear and role.tensor_type == "grad_output":
        return MXFP8Quantizer(...)

    if is_linear:
        return MXFP8Quantizer(...)

    return MXFP8Quantizer(...)

recipe = CustomRecipe(qfactory=hybrid_factory)
with autocast(recipe=recipe):
    y = model(x)

By default, the above factory uses columnwise_source="original", so MXFP8 backward operands are quantized from the original high-precision tensor. Use columnwise_source="rowwise_dequantized" when the backward operand should be derived from the dequantized rowwise NVFP4 forward value.

C++ optimizations (fusions, etc.) will come as standalone PRs. cc @kainzhong

TODO:

  • Convergence of base (non-hybrid) recipes
  • HybridFloat8BlockScaling is xfailed under FSDP2 because dim-0 shards can split 128-row block-scale tiles, producing all-gathered scale buffers whose shape does not match the global tensor.
  • Delayed scaling
  • Mid-training recipe change

Follow-up issue tracker #3158.

Integration

Ecosystem integration (all functional, unit-tested):

  • [Done] quantized_model_init
  • [Done] FSDP2 (TODO: optimize communication buffers)
  • [Done] CPU offloading
  • [Done] Activation recomputation
  • [Done] TP/SP (TODO: enable quantized AG)

Megatron-LM integration status:

  • [Done] 1 GPU baseline
  • [Done] DP + distributed optimizer
  • [TODO] quantized_model_init + --fp{4,8}-param-gather + dist opt (persistent low-precision params via quantized_model_init + sharded-master FP32 → quantized cast via quantize_master_weights.)
    - [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination
    including same-format, cross-format Float8, single-direction)
    - [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise) — each rejected per-direction by quantize_master_weights; unblocker is TE-side cast-helper / kernel.
  • [TODO] Megatron-FSDP + --fp{4,8}-param-gather (fix private attribute access)
  • [TODO] Torch FSDP2 + --fp{4,8}-param-gather
    - [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
    - [TODO] NVFP4 sub-storage FSDP2 hooks
  • [Done] Activation recompute
  • [Done] CPU offload
  • [Done] TP/SP/PP
  • [Done] MoE + EP + grouped GEMM (qwen3 MoE; _hybrid_split_quantize under Megatron MoE)

Review

Total diff +14000
New hybrid source (hybrid_tensor.py, hybrid_tensor_storage.py, identity_tensor.py, identity_tensor_storage.py) ~1800
Adjacent modifications ~1500
Tests are the rest (~10K)

Suggested reading order

  1. Foundation — 7553e6a: Python containers + quantize/gemm dispatch/unwrap
  • tensor/hybrid_tensor.py — HybridQuantizer + HybridQuantizedTensor
    -columnwise_source controls whether columnwise quantization uses the original input or the rowwise-dequantized value.
  • tensor/storage/hybrid_tensor_storage.py
  • cpp_extensions/gemm.py — _unwrap_hybrid_A/B
  • common/transpose/quantize_transpose_square_blockwise.cu - Block FP8 columnwise-only null-checks
  • Module hooks in module/{base,grouped_linear,layernorm_linear,layernorm_mlp}.py
  • Tests: TestHybridQuantizer*, TestHybridGemmBitwiseIdentical* (proves zero-overhead vs vanilla recipes when both formats match), TestHybridDirectionUnwrap*, TestHybridGroupedLinear*

1.1 Identity passthrough — b99277a

  • tensor/identity_tensor.py and tensor/storage/identity_tensor_storage.py — IdentityQuantizer / IdentityTensor high-precision passthrough
  • custom_recipes/quantization_factory_zoo.py — examples for high-precision fwd/bwd directions and columnwise_source="rowwise_dequantized"
  • Tests: test_identity_quantizer.py plus hybrid tests covering Identity inside HybridQuantizer
  1. quantized_model_init + FusedAdam — f80f5d0
  • hybrid_tensor.py::HybridQuantizer.update_quantized — delegates to each sub-quantizer; unblocks workspace-cache quantize_() and FusedAdam writeback
  • module/base.py workspace-cache invalidation
  • Tests: TestHybridQuantizedModelInit, TestHybridFusedAdam, TestHybridQuantizedParamsEndToEnd, TestHybridCheckpoint, TestQuantizedParamsEquivalence*
  1. FSDP2 support — 2185b30
  • New base FSDP2 buffer protocol on QuantizedTensorStorage: fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered. Generic, reusable beyond hybrid.
  • Per-format overrides on Float8TensorStorage (direction-aware) and MXFP8TensorStorage (trips/re-applies scale alignment padding around the gather)
  • hybrid_tensor.py::fsdp_pre/post_all_gather + torch_dispatch for the FSDP2 op set (view, split, as_strided, slice, copy_, new_zeros, clone, detach)
  • Non-safety in float8_tensor.py and mxfp8_tensor.py for single-direction sub-storages (columnwise-only on Hopper/L40)
  • Tests: TestHybridTorchDispatchFSDP2Ops, TestHybridFsdpPreAllGatherProtocol, TestHybridFsdpRoundtrip (bitwise-exact against manual all_gather(dequantize(shard))), plus tests/pytorch/distributed/fsdp2_tests/
  1. CPU offloading — 103fffe
  • hybrid_tensor_storage.py::clear() (v1 path) + prepare_for_saving / restore_from_saved chain (v2 path)
  • hybrid_tensor.py::detach() re-wraps each sub-storage via make_like (required by cpu_offload_v2's detach → prepare_for_saving pattern; sharing sub-storage objects would null-out fields on the original)
  • TestHybridCpuOffloadPushPop, plus updates to test_cpu_offloading*.py
  1. Activation recomputation — 16fb371
  • Uses existing QuantizedTensorStorage::prepare_for_saving / restore_from_saved protocol, preserving ordering across both sub-storages
  • Tests: 20 bitwise tests in TestHybridActivationRecompute
  1. TP/SP — a50fd63
  • hybrid_tensor.py::HybridQuantizer.supports_only_rowwise_all_gather — overrides to handle the NVFP4 columnwise-dequantize gap in the BF16 fallback path
  • distributed.py::gather_along_first_dim — hybrid branch re-quantizes with both directions after AG (since hybrid has no _create_transpose synthesis path)
  • Tests: 9 distributed tests in run_hybrid_tp_sp.py / test_hybrid_tp_sp.py
  1. Megatron-LM integration — a164cd3
  • tensor/utils.py::_route_hybrid_to_buckets — per-direction dispatch for quantize_master_weights: iterates both sub-storages, routes each independently into the per-format bucket matching its own sub-quantizer type
  • Hybrid branches in replace_raw_data and post_all_gather_processing
  • Today: per-tensor Float8 sub-quantizers (delayed + current) work in any per-direction combination. Per-block sub-quantizers raise per-direction with in-code TODOs naming the unblocker.
  • Tests: TestHybridQuantizeMasterWeights, TestHybridPostAllGatherProcessing

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented Mar 31, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces hybrid (per-direction) quantization to TransformerEngine, allowing rowwise and columnwise GEMM operands to use different quantization formats (e.g., NVFP4 forward + MXFP8 backward) via a CustomRecipe(qfactory) factory pattern. It also adds a high-precision IdentityQuantizer/IdentityTensor passthrough for unquantized directions.

  • New types: HybridQuantizer, HybridQuantizedTensor, HybridQuantizedTensorStorage (two sub-storages, one per direction) plus IdentityQuantizer/IdentityTensor (no-op quantizer for HP passthrough), with full support for quantize_master_weights, FSDP2, CPU offloading, activation recompute, and TP/SP.
  • GEMM dispatch (gemm.py): _unwrap_hybrid_A/B extract the direction-appropriate sub-storage before cuBLAS dispatch; _reject_unsupported_output_quantizer guards against hybrid/identity output quantizers being silently consumed.
  • Backend enablement: quantize_transpose_square_blockwise.cu gains null-pointer guards for output_c and tile_scales_inv_c, enabling columnwise-only mode needed by hybrid's Float8Block sub-storage; Float8TensorStorage gets direction-aware fsdp_buffer_fields and fsdp_assign_gathered to handle Hopper/L40 columnwise-only sub-storages; the master-weight update path handles transpose-only Float8 shards via a scatter loop.

Confidence Score: 5/5

Safe to merge — the new hybrid and identity quantization paths are well-isolated and thoroughly tested; existing FP8 and MXFP8 paths are unchanged in behavior.

The core dispatch logic (GEMM unwrapping, grouped-linear split-quantize, FSDP2 pre/post all-gather, master-weight update) follows the same contracts as the existing per-format paths. Previous review concerns about mixed hybrid/None quantizer lists, make_empty exception safety, and FSDP2 direction-awareness for Float8 sub-storages are all addressed in this revision. The two remaining notes are defensive/forward-compatibility concerns that do not affect any currently exercised code path.

transformer_engine/pytorch/cpp_extensions/gemm.py (None passthrough from dropped hybrid sub-storage) and transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py (fsdp_extract_buffers single-direction coverage vs fsdp_buffer_fields bidirectional reporting)

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/hybrid_tensor.py New file: HybridQuantizer + HybridQuantizedTensor with full FSDP2, CPU offload, activation recompute, TP/SP, and checkpoint support. fsdp_post_all_gather correctly calls _sync_usage after each sub-storage reassignment.
transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py New file: HybridQuantizedTensorStorage mixin. prepare_for_saving/restore_from_saved correctly chain both sub-storages in declaration order for activation recompute protocol.
transformer_engine/pytorch/tensor/identity_tensor.py New file: IdentityQuantizer/IdentityTensor passthrough. fsdp_pre/post_all_gather and torch_dispatch ops cover the expected paths; _wrap_data_view correctly wraps all shape ops.
transformer_engine/pytorch/cpp_extensions/gemm.py _unwrap_hybrid_A/B correctly maps layout flags to sub-storage direction; _reject_unsupported_output_quantizer guards output. However, when a sub-storage is None (dropped via update_usage) the unwrap silently returns None, causing an obscure C++ crash rather than a clear Python error.
transformer_engine/pytorch/tensor/utils.py _route_hybrid_to_buckets correctly decomposes hybrid tensors per-direction; _update_transpose_only_float8_flat_fragment scatter loop correctly maps row-major FP32 shard to the columnwise FP8 transpose layout; _cast_master_weights_to_identity handles both FSDP and non-FSDP paths.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py fsdp_buffer_fields now direction-aware (returns _transpose when _data is None); fsdp_assign_gathered clears _transpose_invalid when _transpose is the gathered field — fixes the columnwise-only Hopper/L40 FSDP2 bug from previous review threads.
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py New FSDP2 buffer protocol for Float8Block. fsdp_extract_buffers handles one direction at a time but fsdp_buffer_fields can return 4 field names when both directions are populated — a protocol inconsistency for non-hybrid bidirectional usage.
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py New FSDP2 buffer protocol. fsdp_extract_buffers strips scale alignment padding before gather; fsdp_assign_gathered re-pads. Previously-flagged floor-division bug for columnwise scale truncation (line 366) is still present.
transformer_engine/pytorch/module/grouped_linear.py _is_hybrid_quantizer_list rejects None+Hybrid mixed lists (fixes prior review concern); _hybrid_split_quantize correctly dispatches two tex.split_quantize calls and zips results into HybridStorage.
transformer_engine/pytorch/tensor/float8_tensor.py aten.split dispatch correctly handles _data=None (columnwise-only): infers num_splits from t_func_out, uses transpose shape to infer shard shape. aten.view and _ViewFunc similarly guard for _data=None.
transformer_engine/pytorch/distributed.py Hybrid branch in gather_along_first_dim correctly saves and restores prev_row/prev_col usage flags in try/finally; re-quantizes with both directions after AG to compensate for the missing _create_transpose synthesis path.
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Null-pointer guards added for output_c and tile_scales_inv_c enable columnwise-only mode; out_dtype selection picks the correct FP8 dtype when only the transpose is requested.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    HQ[HybridQuantizer
rowwise_quantizer
columnwise_quantizer
columnwise_source] -->|quantize_impl| HQT[HybridQuantizedTensor
_rowwise_storage
_columnwise_storage]

    HQT -->|GEMM dispatch| UA[_unwrap_hybrid_A
layout0==T → rowwise
layout0==N → columnwise]
    HQT -->|GEMM dispatch| UB[_unwrap_hybrid_B
layout1==N → rowwise
layout1==T → columnwise]
    UA --> UI1[_unwrap_identity_tensor
IdentityStorage → dequantize]
    UB --> UI2[_unwrap_identity_tensor
IdentityStorage → dequantize]
    UI1 --> GEMM[general_gemm / general_grouped_gemm]
    UI2 --> GEMM

    HQT -->|FSDP2 pre-AG| EPR[rowwise_storage
.fsdp_extract_buffers]
    HQT -->|FSDP2 pre-AG| EPC[columnwise_storage
.fsdp_extract_buffers]
    EPR --> AG[all_gather_into_tensor
per-direction buffers]
    EPC --> AG
    AG -->|fsdp_post_all_gather| RAG[fsdp_assign_gathered
+ _sync_usage]
    RAG --> HQT2[Reconstructed
HybridQuantizedTensor]

    HQ -->|GroupedLinear| HSQ[_hybrid_split_quantize
tex.split_quantize x2
row + col passes]
    HSQ --> HQS[HybridQuantizedTensorStorage
per-expert]

    subgraph Sub-storage types
        F8[Float8Tensor
delayed/current]
        MX[MXFP8Tensor]
        FBW[Float8BlockwiseQTensor]
        ID[IdentityTensor
high-precision]
    end
    HQT --> F8
    HQT --> MX
    HQT --> FBW
    HQT --> ID
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    HQ[HybridQuantizer
rowwise_quantizer
columnwise_quantizer
columnwise_source] -->|quantize_impl| HQT[HybridQuantizedTensor
_rowwise_storage
_columnwise_storage]

    HQT -->|GEMM dispatch| UA[_unwrap_hybrid_A
layout0==T → rowwise
layout0==N → columnwise]
    HQT -->|GEMM dispatch| UB[_unwrap_hybrid_B
layout1==N → rowwise
layout1==T → columnwise]
    UA --> UI1[_unwrap_identity_tensor
IdentityStorage → dequantize]
    UB --> UI2[_unwrap_identity_tensor
IdentityStorage → dequantize]
    UI1 --> GEMM[general_gemm / general_grouped_gemm]
    UI2 --> GEMM

    HQT -->|FSDP2 pre-AG| EPR[rowwise_storage
.fsdp_extract_buffers]
    HQT -->|FSDP2 pre-AG| EPC[columnwise_storage
.fsdp_extract_buffers]
    EPR --> AG[all_gather_into_tensor
per-direction buffers]
    EPC --> AG
    AG -->|fsdp_post_all_gather| RAG[fsdp_assign_gathered
+ _sync_usage]
    RAG --> HQT2[Reconstructed
HybridQuantizedTensor]

    HQ -->|GroupedLinear| HSQ[_hybrid_split_quantize
tex.split_quantize x2
row + col passes]
    HSQ --> HQS[HybridQuantizedTensorStorage
per-expert]

    subgraph Sub-storage types
        F8[Float8Tensor
delayed/current]
        MX[MXFP8Tensor]
        FBW[Float8BlockwiseQTensor]
        ID[IdentityTensor
high-precision]
    end
    HQT --> F8
    HQT --> MX
    HQT --> FBW
    HQT --> ID
Loading

Reviews (20): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.

Comment on lines +52 to +53
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we handle the case where not all usages are needed? I'd expect something like:

Suggested change
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)
rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None

@negvet negvet May 21, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 4858491

requires_grad: bool = False,
pin_memory: bool = False,
) -> HybridQuantizedTensor:
self.rowwise_quantizer.internal = True

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would not work under FSDP2.

Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated
Comment on lines +1339 to +1355
def factory(role):
if role == "linear_weight":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_mxfp8_quantizer(),
)
if role == "linear_input":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
if role in ("linear_grad_output", "linear_grad_input"):
return HybridQuantizer(
rowwise_quantizer=_make_mxfp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
return None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is horrifying. Good test.

negvet and others added 10 commits April 6, 2026 10:26
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py
negvet and others added 2 commits April 29, 2026 16:02
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
negvet added 3 commits May 13, 2026 12:34
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested a review from ksivaman as a code owner May 21, 2026 13:53
Comment thread transformer_engine/pytorch/tensor/float8_tensor.py
Comment on lines +27 to +30
# DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories
# (lambdas, inner functions referencing captured state) are not picklable,
# so the qfactory must live at module scope. See
# ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines +1177 to +1184
for param in model.parameters():
state = optimizer.state[param]
assert state["exp_avg"].dtype == torch.float32
assert state["exp_avg_sq"].dtype == torch.float32
if "master_param" in state:
assert state["master_param"].dtype == torch.float32

assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not a very strict test, is there a way for us to do some numerical correctness comparisons?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enabled check for the monotonic loss decrease (still mostly sanity), and also enabled hybrid vs vanilla bitwise recipe comparizon, see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.

@negvet

negvet commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

negvet added 2 commits June 12, 2026 13:15
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet

negvet commented Jun 15, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@negvet

negvet commented Jun 16, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@negvet

negvet commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator Author

Enable columnwise_source and hybrid recipes
columnwise_source makes the origin of the columnwise operand explicit in HybridQuantizer (options are {"original", "rowwise_dequantized"}). With columnwise_source="rowwise_dequantized", the backward columnwise operand is built from dequantize(rowwise_fprop_quantized(x)), so backward sees the forward quantization error instead of re-reading the original high-precision tensor. The same mechanism also supports double-quantization, where colwise direction is quantized from the dequantized rowwise (if colwise quantizer in HybridQuantizer is a non-Identity quantizer).

Respect quantizer veto for save_original_inp
save_original_input is now treated as an optimization hint that can be rejected by the quantizer if it would violate recipe semantics. Hybrid quantizers that require the forward-quantized value can force the save-forward path.

negvet and others added 3 commits June 23, 2026 07:33
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment on lines +365 to +370
elif name == "_columnwise_scale_inv" and t is not None:
expected = flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE
if t.size(0) != expected:
t = t[:expected]
buffers.append(t)
return tuple(buffers), {"field_names": names}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 The columnwise scale truncation uses floor division (flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE) instead of ceiling. For a sharded tensor where M is not a multiple of 32, ceil(M/32) scale entries are valid but M//32 are retained — the entry covering the last partial block is silently dropped. After all-gather, dequantization for those boundary rows uses a stale or zero scale. For example with M=48: 2 scale entries valid, but 48//32=1 is used, discarding row 32–47's scale.

Suggested change
elif name == "_columnwise_scale_inv" and t is not None:
expected = flattened_in_shape0 // MXFP8_BLOCK_SCALING_SIZE
if t.size(0) != expected:
t = t[:expected]
buffers.append(t)
return tuple(buffers), {"field_names": names}
elif name == "_columnwise_scale_inv" and t is not None:
expected = math.ceil(flattened_in_shape0 / MXFP8_BLOCK_SCALING_SIZE)
if t.size(0) != expected:
t = t[:expected]
buffers.append(t)
return tuple(buffers), {"field_names": names}

Signed-off-by: Evgeny <etsykunov@nvidia.com>

@kwyss-nvidia kwyss-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Evgeny for this expansive PR!

I'm excited to see the columnwise_source options in hybrid quantizer and that the edge cases for the FSDP protocol are considered and captured in the new tensor types.

LGTM!

assert copied.rowwise_usage is False
assert copied.columnwise_usage is True

def test_rowwise_dequantized_identity_columnwise_matches_rowwise(self, input_tensor):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the coverage for double quantization and the field so that the quantized tensor tracks describes columnwise source.

# ---------------------------------------------------------------------------


# Module-level qfactories (picklable, required for checkpoint serialization).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which qfactories go into the checkpoint serialization?

While it's useful to have provenance of how the checkpoint was created, does the pickling of qfactories mean that the resulting checkpoints won't be read by transformer engine's with the same classes for custom quantization.

Is there any way to override this requirement and load the checkpoint as BF16, ignoring the pickled qfactories?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I clarified the comment: the module-level qfactory is only needed so TE-to-TE quantized-param checkpoints have a stable importable reference for any pickled quantizer/recipe metadata.

For portability: if quantized_model_init is disabled, model weights are normal BF16 tensors and the CustomRecipe extra state is not needed for stateless recipes, so an external consumer can ignore TE _extra_state. With quantized_model_init, the model state_dict stores TE quantized tensor subclasses for any recipe, not just hybrid/CustomRecipe, so TE -> third-party runtime portability would need a separate high-precision/BF16 export path for quantized primary weights.

Would it be useful to plan that quantized_model_init BF16-weight state_dict/export support as a follow-up, or is running without quantized_model_init for portable checkpoints sufficient for your use case?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running without quantized_model_init is sufficient. I did not know about that option!

transa = layout[0] == "T"
transb = layout[1] == "T"

A = _materialize_high_precision(_unwrap_hybrid_A(A, layout))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naming convention could be revisited. It can be interpreted easily but falsely to mean all quantized tensors will be materialized into high precision tensors. Perhaps "_unwrap_if_high_precision"?

@negvet negvet Jun 30, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced with _unwrap_identity_tensor, since we are doing isinstance(tensor, IdentityTensorStorage)


# Linear-only recipe (no attention quantization): the qfactory is the only knob.
recipe = CustomRecipe(qfactory=mxfp8_fwd_nvfp4_bwd_quantizer_factory)
with autocast(recipe=recipe):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pleasantly simple as an API. Thanks Evgeny.

``HybridQuantizer`` terms, that source choice is expressed with
``columnwise_source="rowwise_dequantized"``.

All non-weight roles keep the standard NVFP4 factory behavior, including RHT

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is useful. We have trained with an equivalent recipe for several experiments and I'm looking forward to trying this implementation.

# Return early if recipe state matches recipe
if self.fp8_meta_tensors_initialized:
recipe_state = self.fp8_meta[fp8_meta_tensor_key]
# Follow-up: Match built-in recipes by full config, not just RecipeState type, so

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this follow up liked an issue or this pull request ID, it would be easier to grep for all related follow ups.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create an umbrella tracker #3158 and referenced it

-------
MXFP8 forward plus high-precision backward from the rowwise-dequantized
forward value can be expressed as::

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the zoo, there's an example where the weight tensor is specialized with double quantization. It would be helpful to illustrate that it's possible to customize along the weight/activation/grad axis as well as the rowwise/colwise abstraction and reference the zoo.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, fixed 8c23fed

for sub in (self.rowwise_quantizer, self.columnwise_quantizer):
group = getattr(sub, "amax_reduction_group", None)
if group is not None:
return group

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arguably, this should assert if there are two groups, they are consistent. Is that possible?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, fixed 8c23fed

Signed-off-by: Evgeny <etsykunov@nvidia.com>
negvet and others added 2 commits June 30, 2026 13:47
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants